{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: torchvision==0.17.0 in /opt/conda/lib/python3.10/site-packages (0.17.0)\n",
      "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17.0) (1.26.2)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17.0) (2.28.1)\n",
      "Requirement already satisfied: torch==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17.0) (2.2.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.10/site-packages (from torchvision==0.17.0) (10.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2.0->torchvision==0.17.0) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->torchvision==0.17.0) (12.3.101)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17.0) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17.0) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17.0) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->torchvision==0.17.0) (2022.6.15)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2.0->torchvision==0.17.0) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2.0->torchvision==0.17.0) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "# Use %pip (not !pip) so packages install into the running kernel's environment;\n",
    "# versions are fully pinned for reproducibility.\n",
    "%pip install torch==2.2.0\n",
    "%pip install torchvision==0.17.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define model architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Two-conv-layer CNN for 1x28x28 MNIST digits; outputs log-probabilities over 10 classes.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.cn1 = nn.Conv2d(1, 16, 3, 1)   # 1x28x28 -> 16x26x26\n",
    "        self.cn2 = nn.Conv2d(16, 32, 3, 1)  # 16x26x26 -> 32x24x24\n",
    "        self.dp1 = nn.Dropout2d(0.10)       # channel-wise dropout on the 4-D conv feature map\n",
    "        # nn.Dropout (not Dropout2d): this layer sees the flattened 2-D (N, features)\n",
    "        # tensor, and Dropout2d on 2-D input is deprecated (see the warning this cell\n",
    "        # used to emit); nn.Dropout retains the previous behavior.\n",
    "        self.dp2 = nn.Dropout(0.25)\n",
    "        self.fc1 = nn.Linear(4608, 64)      # 4608 = 12 x 12 x 32 after 2x2 max-pooling\n",
    "        self.fc2 = nn.Linear(64, 10)\n",
    " \n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a (N, 1, 28, 28) float batch to (N, 10) log-softmax class scores.\"\"\"\n",
    "        x = self.cn1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.cn2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dp1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dp2(x)\n",
    "        x = self.fc2(x)\n",
    "        op = F.log_softmax(x, dim=1)\n",
    "        return op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define training and inference routines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_dataloader, optim, epoch):\n",
    "    \"\"\"Run one training epoch, logging the batch loss every 10 batches.\n",
    "\n",
    "    Args:\n",
    "        model: network producing log-probabilities (updated in place).\n",
    "        device: torch.device that batches are moved to.\n",
    "        train_dataloader: iterable of (X, y) batches.\n",
    "        optim: optimizer over model.parameters().\n",
    "        epoch: current epoch number (used only in the log line).\n",
    "    \"\"\"\n",
    "    model.train()  # training mode: enables dropout\n",
    "    for b_i, (X, y) in enumerate(train_dataloader):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        optim.zero_grad()\n",
    "        pred_prob = model(X)\n",
    "        loss = F.nll_loss(pred_prob, y) # nll = negative log-likelihood; expects log-probabilities\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        if b_i % 10 == 0:\n",
    "            print('epoch: {} [{}/{} ({:.0f}%)]\\t training loss: {:.6f}'.format(\n",
    "                epoch, b_i * len(X), len(train_dataloader.dataset),\n",
    "                100. * b_i / len(train_dataloader), loss.item()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_dataloader):\n",
    "    \"\"\"Evaluate model on test_dataloader and print mean loss and overall accuracy.\"\"\"\n",
    "    model.eval()  # eval mode: disables dropout\n",
    "    loss = 0\n",
    "    success = 0\n",
    "    with torch.no_grad():  # no gradients needed during evaluation\n",
    "        for X, y in test_dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            pred_prob = model(X)\n",
    "            loss += F.nll_loss(pred_prob, y, reduction='sum').item()  # loss summed across the batch\n",
    "            pred = pred_prob.argmax(dim=1, keepdim=True)  # use argmax to get the most likely prediction\n",
    "            success += pred.eq(y.view_as(pred)).sum().item()\n",
    "\n",
    "    loss /= len(test_dataloader.dataset)  # mean loss per sample\n",
    "\n",
    "    print('\\nTest dataset: Overall Loss: {:.4f}, Overall Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        loss, success, len(test_dataloader.dataset),\n",
    "        100. * success / len(test_dataloader.dataset)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## create data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalization stats are the mean/std over all training pixel values:\n",
    "# train_X.mean()/256. and train_X.std()/256.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1302,), (0.3069,)),\n",
    "])\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),\n",
    "    batch_size=32, shuffle=True)\n",
    "\n",
    "test_dataloader = DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=mnist_transform),\n",
    "    batch_size=500, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define optimizer and run training epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)  # fixed seed for reproducible weight init and shuffling\n",
    "# Prefer the GPU when one is available (the CUDA-enabled torch build is installed);\n",
    "# train()/test() already move each batch to this device.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# The model must live on the same device as the batches (the original relied on\n",
    "# device being hard-coded to \"cpu\", so the missing .to(device) was benign).\n",
    "model = ConvNet().to(device)\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py:1347: UserWarning: dropout2d: Received a 2-D input to dropout2d, which is deprecated and will result in an error in a future release. To retain the behavior and silence this warning, please use dropout instead. Note that dropout2d exists to provide channel-wise dropout on inputs with 2 spatial dimensions, a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).\n",
      "  warnings.warn(warn_msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 [0/60000 (0%)]\t training loss: 2.310609\n",
      "epoch: 1 [320/60000 (1%)]\t training loss: 1.924133\n",
      "epoch: 1 [640/60000 (1%)]\t training loss: 1.313337\n",
      "epoch: 1 [960/60000 (2%)]\t training loss: 0.796470\n",
      "epoch: 1 [1280/60000 (2%)]\t training loss: 0.819801\n",
      "epoch: 1 [1600/60000 (3%)]\t training loss: 0.678459\n",
      "epoch: 1 [1920/60000 (3%)]\t training loss: 0.477066\n",
      "epoch: 1 [2240/60000 (4%)]\t training loss: 0.527125\n",
      "epoch: 1 [2560/60000 (4%)]\t training loss: 0.469366\n",
      "epoch: 1 [2880/60000 (5%)]\t training loss: 0.239070\n",
      "epoch: 1 [3200/60000 (5%)]\t training loss: 0.523908\n",
      "epoch: 1 [3520/60000 (6%)]\t training loss: 0.267831\n",
      "epoch: 1 [3840/60000 (6%)]\t training loss: 0.464627\n",
      "epoch: 1 [4160/60000 (7%)]\t training loss: 0.413977\n",
      "epoch: 1 [4480/60000 (7%)]\t training loss: 0.324858\n",
      "epoch: 1 [4800/60000 (8%)]\t training loss: 0.495965\n",
      "epoch: 1 [5120/60000 (9%)]\t training loss: 0.153463\n",
      "epoch: 1 [5440/60000 (9%)]\t training loss: 0.377620\n",
      "epoch: 1 [5760/60000 (10%)]\t training loss: 0.097269\n",
      "epoch: 1 [6080/60000 (10%)]\t training loss: 0.181864\n",
      "epoch: 1 [6400/60000 (11%)]\t training loss: 0.295800\n",
      "epoch: 1 [6720/60000 (11%)]\t training loss: 0.082916\n",
      "epoch: 1 [7040/60000 (12%)]\t training loss: 0.353389\n",
      "epoch: 1 [7360/60000 (12%)]\t training loss: 0.040657\n",
      "epoch: 1 [7680/60000 (13%)]\t training loss: 0.365705\n",
      "epoch: 1 [8000/60000 (13%)]\t training loss: 0.231380\n",
      "epoch: 1 [8320/60000 (14%)]\t training loss: 0.321043\n",
      "epoch: 1 [8640/60000 (14%)]\t training loss: 0.211378\n",
      "epoch: 1 [8960/60000 (15%)]\t training loss: 0.054053\n",
      "epoch: 1 [9280/60000 (15%)]\t training loss: 0.337620\n",
      "epoch: 1 [9600/60000 (16%)]\t training loss: 0.216428\n",
      "epoch: 1 [9920/60000 (17%)]\t training loss: 0.174468\n",
      "epoch: 1 [10240/60000 (17%)]\t training loss: 0.213010\n",
      "epoch: 1 [10560/60000 (18%)]\t training loss: 0.357055\n",
      "epoch: 1 [10880/60000 (18%)]\t training loss: 0.070653\n",
      "epoch: 1 [11200/60000 (19%)]\t training loss: 0.055425\n",
      "epoch: 1 [11520/60000 (19%)]\t training loss: 0.230391\n",
      "epoch: 1 [11840/60000 (20%)]\t training loss: 0.048202\n",
      "epoch: 1 [12160/60000 (20%)]\t training loss: 0.403469\n",
      "epoch: 1 [12480/60000 (21%)]\t training loss: 0.025848\n",
      "epoch: 1 [12800/60000 (21%)]\t training loss: 0.104852\n",
      "epoch: 1 [13120/60000 (22%)]\t training loss: 0.101376\n",
      "epoch: 1 [13440/60000 (22%)]\t training loss: 0.178021\n",
      "epoch: 1 [13760/60000 (23%)]\t training loss: 0.070927\n",
      "epoch: 1 [14080/60000 (23%)]\t training loss: 0.048668\n",
      "epoch: 1 [14400/60000 (24%)]\t training loss: 0.090581\n",
      "epoch: 1 [14720/60000 (25%)]\t training loss: 0.165846\n",
      "epoch: 1 [15040/60000 (25%)]\t training loss: 0.324705\n",
      "epoch: 1 [15360/60000 (26%)]\t training loss: 0.077008\n",
      "epoch: 1 [15680/60000 (26%)]\t training loss: 0.128724\n",
      "epoch: 1 [16000/60000 (27%)]\t training loss: 0.084200\n",
      "epoch: 1 [16320/60000 (27%)]\t training loss: 0.158558\n",
      "epoch: 1 [16640/60000 (28%)]\t training loss: 0.040566\n",
      "epoch: 1 [16960/60000 (28%)]\t training loss: 0.287925\n",
      "epoch: 1 [17280/60000 (29%)]\t training loss: 0.057604\n",
      "epoch: 1 [17600/60000 (29%)]\t training loss: 0.031712\n",
      "epoch: 1 [17920/60000 (30%)]\t training loss: 0.046777\n",
      "epoch: 1 [18240/60000 (30%)]\t training loss: 0.025740\n",
      "epoch: 1 [18560/60000 (31%)]\t training loss: 0.065141\n",
      "epoch: 1 [18880/60000 (31%)]\t training loss: 0.220227\n",
      "epoch: 1 [19200/60000 (32%)]\t training loss: 0.102033\n",
      "epoch: 1 [19520/60000 (33%)]\t training loss: 0.068922\n",
      "epoch: 1 [19840/60000 (33%)]\t training loss: 0.075009\n",
      "epoch: 1 [20160/60000 (34%)]\t training loss: 0.111961\n",
      "epoch: 1 [20480/60000 (34%)]\t training loss: 0.056665\n",
      "epoch: 1 [20800/60000 (35%)]\t training loss: 0.234118\n",
      "epoch: 1 [21120/60000 (35%)]\t training loss: 0.303591\n",
      "epoch: 1 [21440/60000 (36%)]\t training loss: 0.217894\n",
      "epoch: 1 [21760/60000 (36%)]\t training loss: 0.091958\n",
      "epoch: 1 [22080/60000 (37%)]\t training loss: 0.079766\n",
      "epoch: 1 [22400/60000 (37%)]\t training loss: 0.151057\n",
      "epoch: 1 [22720/60000 (38%)]\t training loss: 0.262819\n",
      "epoch: 1 [23040/60000 (38%)]\t training loss: 0.061579\n",
      "epoch: 1 [23360/60000 (39%)]\t training loss: 0.135419\n",
      "epoch: 1 [23680/60000 (39%)]\t training loss: 0.285402\n",
      "epoch: 1 [24000/60000 (40%)]\t training loss: 0.148586\n",
      "epoch: 1 [24320/60000 (41%)]\t training loss: 0.066807\n",
      "epoch: 1 [24640/60000 (41%)]\t training loss: 0.126747\n",
      "epoch: 1 [24960/60000 (42%)]\t training loss: 0.026213\n",
      "epoch: 1 [25280/60000 (42%)]\t training loss: 0.024073\n",
      "epoch: 1 [25600/60000 (43%)]\t training loss: 0.600073\n",
      "epoch: 1 [25920/60000 (43%)]\t training loss: 0.497171\n",
      "epoch: 1 [26240/60000 (44%)]\t training loss: 0.118145\n",
      "epoch: 1 [26560/60000 (44%)]\t training loss: 0.040906\n",
      "epoch: 1 [26880/60000 (45%)]\t training loss: 0.154461\n",
      "epoch: 1 [27200/60000 (45%)]\t training loss: 0.013285\n",
      "epoch: 1 [27520/60000 (46%)]\t training loss: 0.027993\n",
      "epoch: 1 [27840/60000 (46%)]\t training loss: 0.127872\n",
      "epoch: 1 [28160/60000 (47%)]\t training loss: 0.040515\n",
      "epoch: 1 [28480/60000 (47%)]\t training loss: 0.402136\n",
      "epoch: 1 [28800/60000 (48%)]\t training loss: 0.138493\n",
      "epoch: 1 [29120/60000 (49%)]\t training loss: 0.207102\n",
      "epoch: 1 [29440/60000 (49%)]\t training loss: 0.012119\n",
      "epoch: 1 [29760/60000 (50%)]\t training loss: 0.054411\n",
      "epoch: 1 [30080/60000 (50%)]\t training loss: 0.546068\n",
      "epoch: 1 [30400/60000 (51%)]\t training loss: 0.144434\n",
      "epoch: 1 [30720/60000 (51%)]\t training loss: 0.225871\n",
      "epoch: 1 [31040/60000 (52%)]\t training loss: 0.069757\n",
      "epoch: 1 [31360/60000 (52%)]\t training loss: 0.068556\n",
      "epoch: 1 [31680/60000 (53%)]\t training loss: 0.034509\n",
      "epoch: 1 [32000/60000 (53%)]\t training loss: 0.015917\n",
      "epoch: 1 [32320/60000 (54%)]\t training loss: 0.107069\n",
      "epoch: 1 [32640/60000 (54%)]\t training loss: 0.110502\n",
      "epoch: 1 [32960/60000 (55%)]\t training loss: 0.077687\n",
      "epoch: 1 [33280/60000 (55%)]\t training loss: 0.276489\n",
      "epoch: 1 [33600/60000 (56%)]\t training loss: 0.186835\n",
      "epoch: 1 [33920/60000 (57%)]\t training loss: 0.017603\n",
      "epoch: 1 [34240/60000 (57%)]\t training loss: 0.025609\n",
      "epoch: 1 [34560/60000 (58%)]\t training loss: 0.003528\n",
      "epoch: 1 [34880/60000 (58%)]\t training loss: 0.123900\n",
      "epoch: 1 [35200/60000 (59%)]\t training loss: 0.229518\n",
      "epoch: 1 [35520/60000 (59%)]\t training loss: 0.131734\n",
      "epoch: 1 [35840/60000 (60%)]\t training loss: 0.015819\n",
      "epoch: 1 [36160/60000 (60%)]\t training loss: 0.165748\n",
      "epoch: 1 [36480/60000 (61%)]\t training loss: 0.117925\n",
      "epoch: 1 [36800/60000 (61%)]\t training loss: 0.140046\n",
      "epoch: 1 [37120/60000 (62%)]\t training loss: 0.043650\n",
      "epoch: 1 [37440/60000 (62%)]\t training loss: 0.216800\n",
      "epoch: 1 [37760/60000 (63%)]\t training loss: 0.032769\n",
      "epoch: 1 [38080/60000 (63%)]\t training loss: 0.113341\n",
      "epoch: 1 [38400/60000 (64%)]\t training loss: 0.074898\n",
      "epoch: 1 [38720/60000 (65%)]\t training loss: 0.488738\n",
      "epoch: 1 [39040/60000 (65%)]\t training loss: 0.048389\n",
      "epoch: 1 [39360/60000 (66%)]\t training loss: 0.176396\n",
      "epoch: 1 [39680/60000 (66%)]\t training loss: 0.021329\n",
      "epoch: 1 [40000/60000 (67%)]\t training loss: 0.106253\n",
      "epoch: 1 [40320/60000 (67%)]\t training loss: 0.065340\n",
      "epoch: 1 [40640/60000 (68%)]\t training loss: 0.202225\n",
      "epoch: 1 [40960/60000 (68%)]\t training loss: 0.194045\n",
      "epoch: 1 [41280/60000 (69%)]\t training loss: 0.213245\n",
      "epoch: 1 [41600/60000 (69%)]\t training loss: 0.155805\n",
      "epoch: 1 [41920/60000 (70%)]\t training loss: 0.040492\n",
      "epoch: 1 [42240/60000 (70%)]\t training loss: 0.011083\n",
      "epoch: 1 [42560/60000 (71%)]\t training loss: 0.037179\n",
      "epoch: 1 [42880/60000 (71%)]\t training loss: 0.100338\n",
      "epoch: 1 [43200/60000 (72%)]\t training loss: 0.045283\n",
      "epoch: 1 [43520/60000 (73%)]\t training loss: 0.226100\n",
      "epoch: 1 [43840/60000 (73%)]\t training loss: 0.040024\n",
      "epoch: 1 [44160/60000 (74%)]\t training loss: 0.116498\n",
      "epoch: 1 [44480/60000 (74%)]\t training loss: 0.061347\n",
      "epoch: 1 [44800/60000 (75%)]\t training loss: 0.250636\n",
      "epoch: 1 [45120/60000 (75%)]\t training loss: 0.015089\n",
      "epoch: 1 [45440/60000 (76%)]\t training loss: 0.179538\n",
      "epoch: 1 [45760/60000 (76%)]\t training loss: 0.005873\n",
      "epoch: 1 [46080/60000 (77%)]\t training loss: 0.246703\n",
      "epoch: 1 [46400/60000 (77%)]\t training loss: 0.059748\n",
      "epoch: 1 [46720/60000 (78%)]\t training loss: 0.086515\n",
      "epoch: 1 [47040/60000 (78%)]\t training loss: 0.056783\n",
      "epoch: 1 [47360/60000 (79%)]\t training loss: 0.026905\n",
      "epoch: 1 [47680/60000 (79%)]\t training loss: 0.229662\n",
      "epoch: 1 [48000/60000 (80%)]\t training loss: 0.219784\n",
      "epoch: 1 [48320/60000 (81%)]\t training loss: 0.173513\n",
      "epoch: 1 [48640/60000 (81%)]\t training loss: 0.020131\n",
      "epoch: 1 [48960/60000 (82%)]\t training loss: 0.092622\n",
      "epoch: 1 [49280/60000 (82%)]\t training loss: 0.143768\n",
      "epoch: 1 [49600/60000 (83%)]\t training loss: 0.227249\n",
      "epoch: 1 [49920/60000 (83%)]\t training loss: 0.129910\n",
      "epoch: 1 [50240/60000 (84%)]\t training loss: 0.071571\n",
      "epoch: 1 [50560/60000 (84%)]\t training loss: 0.003748\n",
      "epoch: 1 [50880/60000 (85%)]\t training loss: 0.004240\n",
      "epoch: 1 [51200/60000 (85%)]\t training loss: 0.232011\n",
      "epoch: 1 [51520/60000 (86%)]\t training loss: 0.028602\n",
      "epoch: 1 [51840/60000 (86%)]\t training loss: 0.013641\n",
      "epoch: 1 [52160/60000 (87%)]\t training loss: 0.010407\n",
      "epoch: 1 [52480/60000 (87%)]\t training loss: 0.014092\n",
      "epoch: 1 [52800/60000 (88%)]\t training loss: 0.041886\n",
      "epoch: 1 [53120/60000 (89%)]\t training loss: 0.029681\n",
      "epoch: 1 [53440/60000 (89%)]\t training loss: 0.032927\n",
      "epoch: 1 [53760/60000 (90%)]\t training loss: 0.044958\n",
      "epoch: 1 [54080/60000 (90%)]\t training loss: 0.014901\n",
      "epoch: 1 [54400/60000 (91%)]\t training loss: 0.136652\n",
      "epoch: 1 [54720/60000 (91%)]\t training loss: 0.011173\n",
      "epoch: 1 [55040/60000 (92%)]\t training loss: 0.002030\n",
      "epoch: 1 [55360/60000 (92%)]\t training loss: 0.066842\n",
      "epoch: 1 [55680/60000 (93%)]\t training loss: 0.048196\n",
      "epoch: 1 [56000/60000 (93%)]\t training loss: 0.052259\n",
      "epoch: 1 [56320/60000 (94%)]\t training loss: 0.161982\n",
      "epoch: 1 [56640/60000 (94%)]\t training loss: 0.054209\n",
      "epoch: 1 [56960/60000 (95%)]\t training loss: 0.028114\n",
      "epoch: 1 [57280/60000 (95%)]\t training loss: 0.107114\n",
      "epoch: 1 [57600/60000 (96%)]\t training loss: 0.003919\n",
      "epoch: 1 [57920/60000 (97%)]\t training loss: 0.093449\n",
      "epoch: 1 [58240/60000 (97%)]\t training loss: 0.043330\n",
      "epoch: 1 [58560/60000 (98%)]\t training loss: 0.005689\n",
      "epoch: 1 [58880/60000 (98%)]\t training loss: 0.087195\n",
      "epoch: 1 [59200/60000 (99%)]\t training loss: 0.009148\n",
      "epoch: 1 [59520/60000 (99%)]\t training loss: 0.035150\n",
      "epoch: 1 [59840/60000 (100%)]\t training loss: 0.024602\n",
      "\n",
      "Test dataset: Overall Loss: 0.0478, Overall Accuracy: 9845/10000 (98%)\n",
      "\n",
      "epoch: 2 [0/60000 (0%)]\t training loss: 0.022102\n",
      "epoch: 2 [320/60000 (1%)]\t training loss: 0.040999\n",
      "epoch: 2 [640/60000 (1%)]\t training loss: 0.024328\n",
      "epoch: 2 [960/60000 (2%)]\t training loss: 0.130329\n",
      "epoch: 2 [1280/60000 (2%)]\t training loss: 0.016056\n",
      "epoch: 2 [1600/60000 (3%)]\t training loss: 0.159618\n",
      "epoch: 2 [1920/60000 (3%)]\t training loss: 0.243704\n",
      "epoch: 2 [2240/60000 (4%)]\t training loss: 0.003967\n",
      "epoch: 2 [2560/60000 (4%)]\t training loss: 0.172185\n",
      "epoch: 2 [2880/60000 (5%)]\t training loss: 0.262268\n",
      "epoch: 2 [3200/60000 (5%)]\t training loss: 0.090865\n",
      "epoch: 2 [3520/60000 (6%)]\t training loss: 0.006844\n",
      "epoch: 2 [3840/60000 (6%)]\t training loss: 0.002883\n",
      "epoch: 2 [4160/60000 (7%)]\t training loss: 0.016303\n",
      "epoch: 2 [4480/60000 (7%)]\t training loss: 0.005249\n",
      "epoch: 2 [4800/60000 (8%)]\t training loss: 0.080174\n",
      "epoch: 2 [5120/60000 (9%)]\t training loss: 0.051679\n",
      "epoch: 2 [5440/60000 (9%)]\t training loss: 0.035300\n",
      "epoch: 2 [5760/60000 (10%)]\t training loss: 0.053597\n",
      "epoch: 2 [6080/60000 (10%)]\t training loss: 0.002861\n",
      "epoch: 2 [6400/60000 (11%)]\t training loss: 0.020764\n",
      "epoch: 2 [6720/60000 (11%)]\t training loss: 0.019133\n",
      "epoch: 2 [7040/60000 (12%)]\t training loss: 0.087477\n",
      "epoch: 2 [7360/60000 (12%)]\t training loss: 0.012886\n",
      "epoch: 2 [7680/60000 (13%)]\t training loss: 0.009266\n",
      "epoch: 2 [8000/60000 (13%)]\t training loss: 0.116196\n",
      "epoch: 2 [8320/60000 (14%)]\t training loss: 0.271565\n",
      "epoch: 2 [8640/60000 (14%)]\t training loss: 0.023257\n",
      "epoch: 2 [8960/60000 (15%)]\t training loss: 0.098916\n",
      "epoch: 2 [9280/60000 (15%)]\t training loss: 0.177419\n",
      "epoch: 2 [9600/60000 (16%)]\t training loss: 0.006227\n",
      "epoch: 2 [9920/60000 (17%)]\t training loss: 0.196180\n",
      "epoch: 2 [10240/60000 (17%)]\t training loss: 0.005000\n",
      "epoch: 2 [10560/60000 (18%)]\t training loss: 0.030835\n",
      "epoch: 2 [10880/60000 (18%)]\t training loss: 0.204652\n",
      "epoch: 2 [11200/60000 (19%)]\t training loss: 0.029607\n",
      "epoch: 2 [11520/60000 (19%)]\t training loss: 0.004772\n",
      "epoch: 2 [11840/60000 (20%)]\t training loss: 0.013179\n",
      "epoch: 2 [12160/60000 (20%)]\t training loss: 0.031670\n",
      "epoch: 2 [12480/60000 (21%)]\t training loss: 0.171905\n",
      "epoch: 2 [12800/60000 (21%)]\t training loss: 0.027294\n",
      "epoch: 2 [13120/60000 (22%)]\t training loss: 0.011103\n",
      "epoch: 2 [13440/60000 (22%)]\t training loss: 0.001325\n",
      "epoch: 2 [13760/60000 (23%)]\t training loss: 0.028688\n",
      "epoch: 2 [14080/60000 (23%)]\t training loss: 0.024691\n",
      "epoch: 2 [14400/60000 (24%)]\t training loss: 0.103728\n",
      "epoch: 2 [14720/60000 (25%)]\t training loss: 0.010241\n",
      "epoch: 2 [15040/60000 (25%)]\t training loss: 0.004725\n",
      "epoch: 2 [15360/60000 (26%)]\t training loss: 0.014735\n",
      "epoch: 2 [15680/60000 (26%)]\t training loss: 0.090812\n",
      "epoch: 2 [16000/60000 (27%)]\t training loss: 0.092023\n",
      "epoch: 2 [16320/60000 (27%)]\t training loss: 0.227561\n",
      "epoch: 2 [16640/60000 (28%)]\t training loss: 0.049542\n",
      "epoch: 2 [16960/60000 (28%)]\t training loss: 0.017045\n",
      "epoch: 2 [17280/60000 (29%)]\t training loss: 0.005333\n",
      "epoch: 2 [17600/60000 (29%)]\t training loss: 0.033627\n",
      "epoch: 2 [17920/60000 (30%)]\t training loss: 0.117465\n",
      "epoch: 2 [18240/60000 (30%)]\t training loss: 0.080745\n",
      "epoch: 2 [18560/60000 (31%)]\t training loss: 0.054088\n",
      "epoch: 2 [18880/60000 (31%)]\t training loss: 0.012624\n",
      "epoch: 2 [19200/60000 (32%)]\t training loss: 0.040346\n",
      "epoch: 2 [19520/60000 (33%)]\t training loss: 0.110093\n",
      "epoch: 2 [19840/60000 (33%)]\t training loss: 0.005584\n",
      "epoch: 2 [20160/60000 (34%)]\t training loss: 0.006840\n",
      "epoch: 2 [20480/60000 (34%)]\t training loss: 0.004643\n",
      "epoch: 2 [20800/60000 (35%)]\t training loss: 0.034536\n",
      "epoch: 2 [21120/60000 (35%)]\t training loss: 0.083564\n",
      "epoch: 2 [21440/60000 (36%)]\t training loss: 0.035088\n",
      "epoch: 2 [21760/60000 (36%)]\t training loss: 0.120426\n",
      "epoch: 2 [22080/60000 (37%)]\t training loss: 0.061219\n",
      "epoch: 2 [22400/60000 (37%)]\t training loss: 0.308915\n",
      "epoch: 2 [22720/60000 (38%)]\t training loss: 0.040703\n",
      "epoch: 2 [23040/60000 (38%)]\t training loss: 0.026090\n",
      "epoch: 2 [23360/60000 (39%)]\t training loss: 0.010202\n",
      "epoch: 2 [23680/60000 (39%)]\t training loss: 0.045219\n",
      "epoch: 2 [24000/60000 (40%)]\t training loss: 0.016603\n",
      "epoch: 2 [24320/60000 (41%)]\t training loss: 0.603639\n",
      "epoch: 2 [24640/60000 (41%)]\t training loss: 0.008654\n",
      "epoch: 2 [24960/60000 (42%)]\t training loss: 0.019752\n",
      "epoch: 2 [25280/60000 (42%)]\t training loss: 0.023641\n",
      "epoch: 2 [25600/60000 (43%)]\t training loss: 0.015268\n",
      "epoch: 2 [25920/60000 (43%)]\t training loss: 0.102580\n",
      "epoch: 2 [26240/60000 (44%)]\t training loss: 0.017379\n",
      "epoch: 2 [26560/60000 (44%)]\t training loss: 0.004729\n",
      "epoch: 2 [26880/60000 (45%)]\t training loss: 0.008352\n",
      "epoch: 2 [27200/60000 (45%)]\t training loss: 0.015206\n",
      "epoch: 2 [27520/60000 (46%)]\t training loss: 0.151630\n",
      "epoch: 2 [27840/60000 (46%)]\t training loss: 0.014150\n",
      "epoch: 2 [28160/60000 (47%)]\t training loss: 0.001390\n",
      "epoch: 2 [28480/60000 (47%)]\t training loss: 0.214062\n",
      "epoch: 2 [28800/60000 (48%)]\t training loss: 0.167573\n",
      "epoch: 2 [29120/60000 (49%)]\t training loss: 0.066168\n",
      "epoch: 2 [29440/60000 (49%)]\t training loss: 0.002024\n",
      "epoch: 2 [29760/60000 (50%)]\t training loss: 0.026353\n",
      "epoch: 2 [30080/60000 (50%)]\t training loss: 0.199319\n",
      "epoch: 2 [30400/60000 (51%)]\t training loss: 0.181320\n",
      "epoch: 2 [30720/60000 (51%)]\t training loss: 0.116269\n",
      "epoch: 2 [31040/60000 (52%)]\t training loss: 0.007827\n",
      "epoch: 2 [31360/60000 (52%)]\t training loss: 0.031388\n",
      "epoch: 2 [31680/60000 (53%)]\t training loss: 0.037498\n",
      "epoch: 2 [32000/60000 (53%)]\t training loss: 0.092244\n",
      "epoch: 2 [32320/60000 (54%)]\t training loss: 0.080851\n",
      "epoch: 2 [32640/60000 (54%)]\t training loss: 0.040496\n",
      "epoch: 2 [32960/60000 (55%)]\t training loss: 0.013496\n",
      "epoch: 2 [33280/60000 (55%)]\t training loss: 0.092531\n",
      "epoch: 2 [33600/60000 (56%)]\t training loss: 0.017307\n",
      "epoch: 2 [33920/60000 (57%)]\t training loss: 0.134077\n",
      "epoch: 2 [34240/60000 (57%)]\t training loss: 0.175124\n",
      "epoch: 2 [34560/60000 (58%)]\t training loss: 0.220506\n",
      "epoch: 2 [34880/60000 (58%)]\t training loss: 0.018582\n",
      "epoch: 2 [35200/60000 (59%)]\t training loss: 0.073093\n",
      "epoch: 2 [35520/60000 (59%)]\t training loss: 0.030787\n",
      "epoch: 2 [35840/60000 (60%)]\t training loss: 0.010614\n",
      "epoch: 2 [36160/60000 (60%)]\t training loss: 0.034699\n",
      "epoch: 2 [36480/60000 (61%)]\t training loss: 0.034723\n",
      "epoch: 2 [36800/60000 (61%)]\t training loss: 0.009217\n",
      "epoch: 2 [37120/60000 (62%)]\t training loss: 0.005739\n",
      "epoch: 2 [37440/60000 (62%)]\t training loss: 0.040345\n",
      "epoch: 2 [37760/60000 (63%)]\t training loss: 0.218793\n",
      "epoch: 2 [38080/60000 (63%)]\t training loss: 0.198174\n",
      "epoch: 2 [38400/60000 (64%)]\t training loss: 0.035814\n",
      "epoch: 2 [38720/60000 (65%)]\t training loss: 0.068463\n",
      "epoch: 2 [39040/60000 (65%)]\t training loss: 0.007639\n",
      "epoch: 2 [39360/60000 (66%)]\t training loss: 0.026163\n",
      "epoch: 2 [39680/60000 (66%)]\t training loss: 0.148529\n",
      "epoch: 2 [40000/60000 (67%)]\t training loss: 0.075740\n",
      "epoch: 2 [40320/60000 (67%)]\t training loss: 0.047596\n",
      "epoch: 2 [40640/60000 (68%)]\t training loss: 0.013017\n",
      "epoch: 2 [40960/60000 (68%)]\t training loss: 0.063155\n",
      "epoch: 2 [41280/60000 (69%)]\t training loss: 0.022223\n",
      "epoch: 2 [41600/60000 (69%)]\t training loss: 0.024700\n",
      "epoch: 2 [41920/60000 (70%)]\t training loss: 0.001801\n",
      "epoch: 2 [42240/60000 (70%)]\t training loss: 0.002220\n",
      "epoch: 2 [42560/60000 (71%)]\t training loss: 0.048338\n",
      "epoch: 2 [42880/60000 (71%)]\t training loss: 0.046190\n",
      "epoch: 2 [43200/60000 (72%)]\t training loss: 0.004485\n",
      "epoch: 2 [43520/60000 (73%)]\t training loss: 0.107624\n",
      "epoch: 2 [43840/60000 (73%)]\t training loss: 0.078155\n",
      "epoch: 2 [44160/60000 (74%)]\t training loss: 0.156051\n",
      "epoch: 2 [44480/60000 (74%)]\t training loss: 0.001538\n",
      "epoch: 2 [44800/60000 (75%)]\t training loss: 0.300366\n",
      "epoch: 2 [45120/60000 (75%)]\t training loss: 0.041359\n",
      "epoch: 2 [45440/60000 (76%)]\t training loss: 0.390815\n",
      "epoch: 2 [45760/60000 (76%)]\t training loss: 0.024482\n",
      "epoch: 2 [46080/60000 (77%)]\t training loss: 0.017786\n",
      "epoch: 2 [46400/60000 (77%)]\t training loss: 0.028327\n",
      "epoch: 2 [46720/60000 (78%)]\t training loss: 0.001570\n",
      "epoch: 2 [47040/60000 (78%)]\t training loss: 0.066162\n",
      "epoch: 2 [47360/60000 (79%)]\t training loss: 0.154574\n",
      "epoch: 2 [47680/60000 (79%)]\t training loss: 0.091009\n",
      "epoch: 2 [48000/60000 (80%)]\t training loss: 0.042673\n",
      "epoch: 2 [48320/60000 (81%)]\t training loss: 0.106071\n",
      "epoch: 2 [48640/60000 (81%)]\t training loss: 0.004252\n",
      "epoch: 2 [48960/60000 (82%)]\t training loss: 0.014724\n",
      "epoch: 2 [49280/60000 (82%)]\t training loss: 0.010253\n",
      "epoch: 2 [49600/60000 (83%)]\t training loss: 0.010249\n",
      "epoch: 2 [49920/60000 (83%)]\t training loss: 0.362407\n",
      "epoch: 2 [50240/60000 (84%)]\t training loss: 0.115282\n",
      "epoch: 2 [50560/60000 (84%)]\t training loss: 0.106830\n",
      "epoch: 2 [50880/60000 (85%)]\t training loss: 0.105281\n",
      "epoch: 2 [51200/60000 (85%)]\t training loss: 0.119776\n",
      "epoch: 2 [51520/60000 (86%)]\t training loss: 0.067086\n",
      "epoch: 2 [51840/60000 (86%)]\t training loss: 0.023813\n",
      "epoch: 2 [52160/60000 (87%)]\t training loss: 0.047860\n",
      "epoch: 2 [52480/60000 (87%)]\t training loss: 0.005951\n",
      "epoch: 2 [52800/60000 (88%)]\t training loss: 0.003520\n",
      "epoch: 2 [53120/60000 (89%)]\t training loss: 0.018581\n",
      "epoch: 2 [53440/60000 (89%)]\t training loss: 0.055058\n",
      "epoch: 2 [53760/60000 (90%)]\t training loss: 0.060451\n",
      "epoch: 2 [54080/60000 (90%)]\t training loss: 0.047163\n",
      "epoch: 2 [54400/60000 (91%)]\t training loss: 0.003440\n",
      "epoch: 2 [54720/60000 (91%)]\t training loss: 0.153727\n",
      "epoch: 2 [55040/60000 (92%)]\t training loss: 0.126677\n",
      "epoch: 2 [55360/60000 (92%)]\t training loss: 0.031780\n",
      "epoch: 2 [55680/60000 (93%)]\t training loss: 0.010447\n",
      "epoch: 2 [56000/60000 (93%)]\t training loss: 0.036427\n",
      "epoch: 2 [56320/60000 (94%)]\t training loss: 0.006692\n",
      "epoch: 2 [56640/60000 (94%)]\t training loss: 0.001685\n",
      "epoch: 2 [56960/60000 (95%)]\t training loss: 0.124056\n",
      "epoch: 2 [57280/60000 (95%)]\t training loss: 0.169876\n",
      "epoch: 2 [57600/60000 (96%)]\t training loss: 0.228522\n",
      "epoch: 2 [57920/60000 (97%)]\t training loss: 0.115276\n",
      "epoch: 2 [58240/60000 (97%)]\t training loss: 0.042843\n",
      "epoch: 2 [58560/60000 (98%)]\t training loss: 0.003510\n",
      "epoch: 2 [58880/60000 (98%)]\t training loss: 0.008058\n",
      "epoch: 2 [59200/60000 (99%)]\t training loss: 0.006398\n",
      "epoch: 2 [59520/60000 (99%)]\t training loss: 0.016740\n",
      "epoch: 2 [59840/60000 (100%)]\t training loss: 0.039168\n",
      "\n",
      "Test dataset: Overall Loss: 0.0385, Overall Accuracy: 9872/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, 3):\n",
    "    train(model, device, train_dataloader, optimizer, epoch)\n",
    "    test(model, device, test_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference on the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAb/ElEQVR4nO3df2xV9f3H8dflR6+o7e1KbW+v/LDgDxYRVCa1QxFHQ6mbAyQGnX/gphBcMVOmLl3kh0rWyZJpNB3+2AIzE1QSgfkjXbTaNmrBUCFEt1VKOinSlknCvaXQwtrP9w++3u1KC57LvX3flucj+STcc877nnePp/fluff0c33OOScAAPrZEOsGAADnJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJoZZN/BNPT09OnDggNLT0+Xz+azbAQB45JxTe3u7QqGQhgzp+zon5QLowIEDGj16tHUbAICz1NzcrFGjRvW5PuXegktPT7duAQCQAGd6PU9aAFVUVOiSSy7Reeedp4KCAn388cffqo633QBgcDjT63lSAujVV1/VsmXLtHLlSn3yySeaPHmyiouLdfDgwWTsDgAwELkkmDp1qistLY0+7u7udqFQyJWXl5+xNhwOO0kMBoPBGOAjHA6f9vU+4VdAx48fV319vYqKiqLLhgwZoqKiItXV1Z2yfVdXlyKRSMwAAAx+CQ+gr776St3d3crNzY1Znpubq9bW1lO2Ly8vVyAQiA7ugAOAc4P5XXBlZWUKh8PR0dzcbN0SAKAfJPzvgLKzszV06FC1tbXFLG9ra1MwGDxle7/fL7/fn+g2AAApLuFXQGlpaZoyZYqqqqqiy3p6elRVVaXCwsJE7w4AMEAlZSaEZcuWaeHChfre976nqVOn6umnn1ZHR4d++tOfJmN3AIABKCkBtGDBAv373//WihUr1NraqquvvlqVlZWn3JgAADh3+ZxzzrqJ/xWJRBQIBKzbAACcpXA4rIyMjD7Xm98FBwA4NxFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwMQw6waAZMjNzY2r7o9//KPnmvT0dM819fX1nms2btzouWbXrl2eayTpP//5T1x1gBdcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDhc8456yb+VyQSUSAQsG4DSeL3+z3X3HLLLZ5rnn/+ec81kpSVleW5xufzea7pr1+7t99+O666Y8eOea5paWnxXLN8+XLPNe3t7Z5rYCMcDisjI6PP9VwBAQBMEEAAABMJD6BVq1bJ5/PFjAkTJiR6NwCAAS4pX0h35ZVX6t133/3vTobxvXcAgFhJSYZhw4YpGAwm46kBAINEUj4D2rNnj0KhkMaNG6e77rpL+/bt63Pbrq4uRSKRmAEAGPwSHkAFBQVav369KisrtXbtWjU1NenGG2/s89bJ8vJyBQKB6Bg9enSiWwIApKCEB1BJSYluv/12TZo0ScXFxXr77bd1+PBhvfbaa71uX1ZWpnA4HB3Nzc2JbgkAkIKSfndAZmamLr/8cjU2Nva63u/3x/XHiQCAgS3pfwd05MgR7d27V3l5ecneFQBgAEl4AD300EOqqanRv/71L3300UeaN2+ehg4dqjvvvDPRuwIADGAJfwtu//79uvPOO3Xo0CFddNFFuuGGG7Rt2zZddNFFid4VAGAAYzJS9KsXXnjBc83PfvazJHSSOKk8GWm8+utn+uyzzzzX1NbW
eq559tlnPddI0ueffx5XHU5iMlIAQEoigAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggslIEbe1a9d6rlm8eLHnmv48RT/55BPPNfFMjhnP92MtWLDAc82LL77ouUaSdu3a5bmmoqIirn31hwMHDsRVV1RU5LmGCUz/i8lIAQApiQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABggtmwoRtuuCGuurfeestzzelmxu1LZ2en55pnnnnGc40krV692nNNe3t7XPsabEKhkOeaeGb4Xr58ueeazMxMzzWS9OWXX3quief36YsvvvBcMxAwGzYAICURQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwwWSkg8zll1/uueaDDz6Ia19ZWVmea3w+n+ea119/3XPN7bff7rkGA0M8k5GuWrUqrn3F8/L4m9/8xnPNihUrPNcMBExGCgBISQQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwGekgc/3113uuiXcy0nh89NFHnmvmzZvnuebQoUOeazB4dXd3x1UXz8tjXV2d55ri4mLPNUePHvVc09+YjBQAkJIIIACACc8BVFtbq1tvvVWhUEg+n09btmyJWe+c04oVK5SXl6cRI0aoqKhIe/bsSVS/AIBBwnMAdXR0aPLkyaqoqOh1/Zo1a/TMM8/oueee0/bt23XBBReouLhYnZ2dZ90sAGDwGOa1oKSkRCUlJb2uc87p6aef1qOPPqo5c+ZIkl566SXl5uZqy5YtuuOOO86uWwDAoJHQz4CamprU2tqqoqKi6LJAIKCCgoI+7wzp6upSJBKJGQCAwS+hAdTa2ipJys3NjVmem5sbXfdN5eXlCgQC0TF69OhEtgQASFHmd8GVlZUpHA5HR3Nzs3VLAIB+kNAACgaDkqS2traY5W1tbdF13+T3+5WRkREzAACDX0IDKD8/X8FgUFVVVdFlkUhE27dvV2FhYSJ3BQAY4DzfBXfkyBE1NjZGHzc1NWnXrl3KysrSmDFj9MADD2j16tW67LLLlJ+fr+XLlysUCmnu3LmJ7BsAMMB5DqAdO3bo5ptvjj5etmyZJGnhwoVav369HnnkEXV0dGjx4sU6fPiwbrjhBlVWVuq8885LXNcAgAGPyUhTWHp6uueat956y3PN97//fc818Xr44Yc91zz11FNJ6ATnkueffz6uunvuucdzzeeff+65Jp6PKMLhsOea/sZkpACAlEQAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMOH56xjQf+bMmeO5Ztq0aUnopHc1NTWea5jZGhaOHDkSV53P5/Nc09nZ6bmmp6fHc81gwBUQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0xGmsKuueYazzXOuSR00rsnnnii3/YFWIjn9ykrK8tzTVpamueawYArIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACaYjBRx27Nnj3ULQMoZNWqU55oRI0YkoZPUxxUQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0xGCmDQu/baa61bQC+4AgIAmCCAAAAmPAdQbW2tbr31VoVCIfl8Pm3ZsiVm/d133y2fzxczZs+enah+AQCDhOcA6ujo0OTJk1VRUdHnNrNnz1ZLS0t0bNy48ayaBAAMPp5vQigpKVFJSclpt/H7/QoGg3E3BQAY/JLyGVB1
dbVycnJ0xRVX6L777tOhQ4f63Larq0uRSCRmAAAGv4QH0OzZs/XSSy+pqqpKTz75pGpqalRSUqLu7u5ety8vL1cgEIiO0aNHJ7olAEAKSvjfAd1xxx3Rf1911VWaNGmSxo8fr+rqas2cOfOU7cvKyrRs2bLo40gkQggBwDkg6bdhjxs3TtnZ2WpsbOx1vd/vV0ZGRswAAAx+SQ+g/fv369ChQ8rLy0v2rgAAA4jnt+COHDkSczXT1NSkXbt2KSsrS1lZWXrsscc0f/58BYNB7d27V4888oguvfRSFRcXJ7RxAMDA5jmAduzYoZtvvjn6+OvPbxYuXKi1a9dq9+7d+vOf/6zDhw8rFApp1qxZeuKJJ+T3+xPXNQBgwPMcQDNmzJBzrs/1f/vb386qIZwdn89n3QKQVD/60Y8819x0001x7et0r3V92bRpk+ea/fv3e64ZDJgLDgBgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgIuFfyQ1b8czeG69gMOi55lyd9Re9mzhxoueaF1980XNNvL8X8dQdOHAgrn2di7gCAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYILJSFPYnj17rFs4rQULFniu2bFjRxI6QaKdf/75nmvuvfdezzWPP/6455oLL7zQc028Xn/9dc81K1euTEIngxNXQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEz4nHPOuon/FYlEFAgErNtICZmZmZ5rKioqPNfEM6moJB05csRzzcyZMz3X1NfXe67BSYsWLYqrbtWqVZ5rcnNz49pXf1i9enVcdU8++aTnmmPHjsW1r8EoHA4rIyOjz/VcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBZKSDTCgU8lzz4YcfxrWvMWPGeK7x+Xyea7788kvPNa+99prnmnjt3LnTc80111yThE5Ode+998ZVl56e7rmmv15K5s2b57nmr3/9axI6wZkwGSkAICURQAAAE54CqLy8XNddd53S09OVk5OjuXPnqqGhIWabzs5OlZaWauTIkbrwwgs1f/58tbW1JbRpAMDA5ymAampqVFpaqm3btumdd97RiRMnNGvWLHV0dES3efDBB/XGG29o06ZNqqmp0YEDB3TbbbclvHEAwMA2zMvGlZWVMY/Xr1+vnJwc1dfXa/r06QqHw/rTn/6kDRs26Ac/+IEkad26dfrud7+rbdu26frrr09c5wCAAe2sPgMKh8OSpKysLEknvzr5xIkTKioqim4zYcIEjRkzRnV1db0+R1dXlyKRSMwAAAx+cQdQT0+PHnjgAU2bNk0TJ06UJLW2tiotLU2ZmZkx2+bm5qq1tbXX5ykvL1cgEIiO0aNHx9sSAGAAiTuASktL9emnn+qVV145qwbKysoUDoejo7m5+ayeDwAwMHj6DOhrS5cu1Ztvvqna2lqNGjUqujwYDOr48eM6fPhwzFVQW1ubgsFgr8/l9/vl9/vjaQMAMIB5ugJyzmnp0qXavHmz3nvvPeXn58esnzJlioYPH66qqqrosoaGBu3bt0+FhYWJ6RgAMCh4ugIqLS3Vhg0btHXrVqWnp0c/1wkEAhoxYoQCgYDuueceLVu2TFlZWcrIyND999+vwsJC7oADAMTwFEBr166VJM2YMSNm+bp163T33XdLkp566ikNGTJE8+fPV1dXl4qLi/WHP/whIc0CAAYPJiOFbr755rjqXn31Vc81I0eO9FyTYqfoKeKZYHUw/kzfnBXl21ixYoXnmrfeestzzbFjxzzX4OwxGSkAICURQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEA
TBBAAAATBBAAwAQBBAAwQQABAEwwGzbidu2113qu+fGPf+y55tFHH/Vc059SeTbseGaOlqTKykrPNbW1tZ5rPvvsM881GDiYDRsAkJIIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYDJS9Kthw4Z5rrn66qs916xZs8ZzjRTfBKvxTEb6wgsveK6Jx+rVq+OqC4fDCe4E5yImIwUApCQCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmmIwUAJAUTEYKAEhJBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAw4SmAysvLdd111yk9PV05OTmaO3euGhoaYraZMWOGfD5fzFiyZElCmwYADHyeAqimpkalpaXatm2b3nnnHZ04cUKzZs1SR0dHzHaLFi1SS0tLdKxZsyahTQMABr5hXjaurKyMebx+/Xrl5OSovr5e06dPjy4///zzFQwGE9MhAGBQOqvPgMLhsCQpKysrZvnLL7+s7OxsTZw4UWVlZTp69Gifz9HV1aVIJBIzAADnABen7u5u98Mf/tBNmzYtZvnzzz/vKisr3e7du91f/vIXd/HFF7t58+b1+TwrV650khgMBoMxyEY4HD5tjsQdQEuWLHFjx451zc3Np92uqqrKSXKNjY29ru/s7HThcDg6mpubzQ8ag8FgMM5+nCmAPH0G9LWlS5fqzTffVG1trUaNGnXabQsKCiRJjY2NGj9+/Cnr/X6//H5/PG0AAAYwTwHknNP999+vzZs3q7q6Wvn5+Wes2bVrlyQpLy8vrgYBAIOTpwAqLS3Vhg0btHXrVqWnp6u1tVWSFAgENGLECO3du1cbNmzQLbfcopEjR2r37t168MEHNX36dE2aNCkpPwAAYIDy8rmP+nifb926dc455/bt2+emT5/usrKynN/vd5deeql7+OGHz/g+4P8Kh8Pm71syGAwG4+zHmV77ff8fLCkjEokoEAhYtwEAOEvhcFgZGRl9rmcuOACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiZQLIOecdQsAgAQ40+t5ygVQe3u7dQsAgAQ40+u5z6XYJUdPT48OHDig9PR0+Xy+mHWRSESjR49Wc3OzMjIyjDq0x3E4ieNwEsfhJI7DSalwHJxzam9vVygU0pAhfV/nDOvHnr6VIUOGaNSoUafdJiMj45w+wb7GcTiJ43ASx+EkjsNJ1schEAiccZuUewsOAHBuIIAAACYGVAD5/X6tXLlSfr/fuhVTHIeTOA4ncRxO4jicNJCOQ8rdhAAAODcMqCsgAMDgQQABAEwQQAAAEwQQAMDEgAmgiooKXXLJJTrvvPNUUFCgjz/+2Lqlfrdq1Sr5fL6YMWHCBOu2kq62tla33nqrQqGQfD6ftmzZErPeOacVK1YoLy9PI0aMUFFRkfbs2WPTbBKd6Tjcfffdp5wfs2fPtmk2ScrLy3XdddcpPT1dOTk5mjt3rhoaGmK26ezsVGlpqUaOHKkLL7xQ8+fPV1tbm1HHyfFtjsOMGTNOOR+WLFli1HHvBkQAvfrqq1q2bJlWrlypTz75RJMnT1ZxcbEOHjxo3Vq/u/LKK9XS0hIdH3zwgXVLSdfR0aHJkyeroqKi1/Vr1qzRM888o+eee07bt2/XBRdcoOLiYnV2dvZzp8l1puMgSbNnz445PzZu3NiPHSZfTU2NSktLtW3bNr3zzjs6ceKEZs2apY6Ojug2Dz74oN544w1t2rRJNTU1OnDggG677TbDrhPv2xwHSVq0
aFHM+bBmzRqjjvvgBoCpU6e60tLS6OPu7m4XCoVceXm5YVf9b+XKlW7y5MnWbZiS5DZv3hx93NPT44LBoPvd734XXXb48GHn9/vdxo0bDTrsH988Ds45t3DhQjdnzhyTfqwcPHjQSXI1NTXOuZP/7YcPH+42bdoU3eYf//iHk+Tq6uqs2ky6bx4H55y76aab3C9+8Qu7pr6FlL8COn78uOrr61VUVBRdNmTIEBUVFamurs6wMxt79uxRKBTSuHHjdNddd2nfvn3WLZlqampSa2trzPkRCARUUFBwTp4f1dXVysnJ0RVXXKH77rtPhw4dsm4pqcLhsCQpKytLklRfX68TJ07EnA8TJkzQmDFjBvX58M3j8LWXX35Z2dnZmjhxosrKynT06FGL9vqUcpORftNXX32l7u5u5ebmxizPzc3VP//5T6OubBQUFGj9+vW64oor1NLSoscee0w33nijPv30U6Wnp1u3Z6K1tVWSej0/vl53rpg9e7Zuu+025efna+/evfr1r3+tkpIS1dXVaejQodbtJVxPT48eeOABTZs2TRMnTpR08nxIS0tTZmZmzLaD+Xzo7ThI0k9+8hONHTtWoVBIu3fv1q9+9Ss1NDTo9ddfN+w2VsoHEP6rpKQk+u9JkyapoKBAY8eO1WuvvaZ77rnHsDOkgjvuuCP676uuukqTJk3S+PHjVV1drZkzZxp2lhylpaX69NNPz4nPQU+nr+OwePHi6L+vuuoq5eXlaebMmdq7d6/Gjx/f3232KuXfgsvOztbQoUNPuYulra1NwWDQqKvUkJmZqcsvv1yNjY3WrZj5+hzg/DjVuHHjlJ2dPSjPj6VLl+rNN9/U+++/H/P1LcFgUMePH9fhw4djth+s50Nfx6E3BQUFkpRS50PKB1BaWpqmTJmiqqqq6LKenh5VVVWpsLDQsDN7R44c0d69e5WXl2fdipn8/HwFg8GY8yMSiWj79u3n/Pmxf/9+HTp0aFCdH845LV26VJs3b9Z7772n/Pz8mPVTpkzR8OHDY86HhoYG7du3b1CdD2c6Dr3ZtWuXJKXW+WB9F8S38corrzi/3+/Wr1/v/v73v7vFixe7zMxM19raat1av/rlL3/pqqurXVNTk/vwww9dUVGRy87OdgcPHrRuLana29vdzp073c6dO50k9/vf/97t3LnTffHFF845537729+6zMxMt3XrVrd79243Z84cl5+f744dO2bceWKd7ji0t7e7hx56yNXV1bmmpib37rvvumuvvdZddtllrrOz07r1hLnvvvtcIBBw1dXVrqWlJTqOHj0a3WbJkiVuzJgx7r333nM7duxwhYWFrrCw0LDrxDvTcWhsbHSPP/6427Fjh2tqanJbt25148aNc9OnTzfuPNaACCDnnHv22WfdmDFjXFpamps6darbtm2bdUv9bsGCBS4vL8+lpaW5iy++2C1YsMA1NjZat5V077//vpN0yli4cKFz7uSt2MuXL3e5ubnO7/e7mTNnuoaGBtumk+B0x+Ho0aNu1qxZ7qKLLnLDhw93Y8eOdYsWLRp0/5PW288vya1bty66zbFjx9zPf/5z953vfMedf/75bt68ea6lpcWu6SQ403HYt2+fmz59usvKynJ+v99deuml7uGHH3bhcNi28W/g6xgAACZS/jMgAMDgRAABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwMT/Adn1QZE3qWxwAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Grab the first batch from the test loader and display its first image.\n",
    "test_samples = enumerate(test_dataloader)\n",
    "b_i, (sample_data, sample_targets) = next(test_samples)  # b_i: batch index (0 here)\n",
    "\n",
    "# sample_data[0][0]: first image of the batch, first channel -- rendered as a\n",
    "# grayscale digit (presumably MNIST, given the 60000/10000 split in the logs)\n",
    "plt.imshow(sample_data[0][0], cmap='gray', interpolation='none')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model prediction is : 0\n",
      "Ground truth is : 0\n"
     ]
    }
   ],
   "source": [
    "print(f\"Model prediction is : {model(sample_data).data.max(1)[1][0]}\")\n",
    "print(f\"Ground truth is : {sample_targets[0]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist only the trained weights (state_dict), not the whole module object;\n",
    "# reload later with model.load_state_dict(torch.load(PATH_TO_MODEL)).\n",
    "PATH_TO_MODEL = \"./convnet.pth\"\n",
    "torch.save(model.state_dict(), PATH_TO_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
